import pytest
import torch
from ding.framework import OnlineRLContext
from ding.data.buffer import DequeBuffer
from typing import Any
import numpy as np
import copy
from ding.framework.middleware.functional.enhancer import reward_estimator, her_data_enhancer
from unittest.mock import Mock, patch
from ding.framework.middleware.tests import MockHerRewardModel, CONFIG

# Twenty synthetic transitions, each a dict with random 2x2 'obs' and 'next_obs' tensors.
DATA = [dict(obs=torch.rand(2, 2), next_obs=torch.rand(2, 2)) for _ in range(20)]


class MockRewardModel(Mock):
    """Stand-in reward model that only validates the batch it receives.

    ``estimate`` computes nothing and returns ``None``. It asserts that the
    batch has as many entries as ``DATA`` and that the first entry's
    ``'obs'`` tensor equals ``DATA[0]['obs']``.
    """

    def estimate(self, data: list) -> Any:
        assert len(data) == len(DATA)
        received, expected = data[0]['obs'], DATA[0]['obs']
        assert torch.equal(received, expected)


@pytest.mark.unittest
def test_reward_estimator():
    """Run ``reward_estimator`` once on a context holding a deep copy of ``DATA``.

    The checks themselves live in ``MockRewardModel.estimate``. It asserts that
    the batch it receives matches ``DATA``, which is the data placed on
    ``ctx.train_data`` here.
    """
    ctx = OnlineRLContext()
    ctx.train_data = copy.deepcopy(DATA)
    with patch("ding.reward_model.HerRewardModel", MockHerRewardModel):
        model = MockRewardModel()
        enhance = reward_estimator(cfg=None, reward_model=model)
        enhance(ctx)


@pytest.mark.unittest
def test_her_data_enhancer():
    """Exercise ``her_data_enhancer`` with and without a model-defined episode size.

    Phase 1: the mock HER reward model keeps its own ``episode_size``. The test
    expects ``episode_size * episode_element_size`` transitions in
    ``ctx.train_data``.

    Phase 2: ``episode_size`` is set to ``None``. The test expects the count to
    follow ``cfg.policy.learn.batch_size`` instead.

    In both phases the first produced transition is expected to carry the six
    keys of the synthetic transitions built below.
    """
    cfg = copy.deepcopy(CONFIG)
    ctx = OnlineRLContext()

    with patch("ding.reward_model.HerRewardModel", MockHerRewardModel):
        mock_her_reward_model = MockHerRewardModel()
        buffer = DequeBuffer(mock_her_reward_model.episode_size)

        # One list of transitions per episode. Episode lengths are drawn at
        # random from {1, 4, 5} so the enhancer sees episodes of uneven length.
        train_data = [
            [
                {
                    'action': torch.randint(low=0, high=5, size=(1, )),
                    'collect_train_iter': torch.tensor([0]),
                    'done': torch.tensor(False),
                    'next_obs': torch.randint(low=0, high=2, size=(10, ), dtype=torch.float32),
                    'obs': torch.randint(low=0, high=2, size=(10, ), dtype=torch.float32),
                    'reward': torch.randint(low=0, high=2, size=(1, ), dtype=torch.float32),
                } for _ in range(np.random.choice([1, 4, 5]))
            ] for _ in range(mock_her_reward_model.episode_size)
        ]

        for d in train_data:
            buffer.push(d)

        # Phase 1: hand the enhancer the same instance whose attributes the
        # assertions below read.
        her_data_enhancer(cfg=cfg, buffer_=buffer, her_reward_model=mock_her_reward_model)(ctx)
        assert len(ctx.train_data) == mock_her_reward_model.episode_size * mock_her_reward_model.episode_element_size
        assert len(ctx.train_data[0]) == 6

        buffer = DequeBuffer(cfg.policy.learn.batch_size)
        for d in train_data:
            buffer.push(d)
        mock_her_reward_model.episode_size = None
        # Phase 2: this must be the instance modified just above. Constructing
        # a fresh MockHerRewardModel() here would discard the
        # ``episode_size = None`` override, and the batch-size path would never
        # actually be exercised.
        her_data_enhancer(cfg=cfg, buffer_=buffer, her_reward_model=mock_her_reward_model)(ctx)
        assert len(ctx.train_data) == cfg.policy.learn.batch_size * mock_her_reward_model.episode_element_size
        assert len(ctx.train_data[0]) == 6
